import os
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class HunyuanClip(nn.Module):
    """Text encoder that pairs a ``BertTokenizer`` with a ``BertModel``.

    Hunyuan clip code copied from
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
    Hunyuan's clip uses BertModel and BertTokenizer, so that code is copied here.

    Args:
        model_dir: Directory holding a ``tokenizer`` sub-directory and a
            ``clip_text_encoder`` sub-directory, each loaded with
            ``from_pretrained``.
        max_length: Length every prompt is padded or truncated to when
            tokenized.
    """

    def __init__(self, model_dir: str, max_length: int = 77) -> None:
        super().__init__()

        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(
            os.path.join(model_dir, 'tokenizer')
        )
        self.text_encoder = BertModel.from_pretrained(
            os.path.join(model_dir, 'clip_text_encoder')
        )

    # Use the called form ``torch.no_grad()``: the bare ``@torch.no_grad``
    # spelling is rejected by older PyTorch releases (presumably anything
    # before 2.x - TODO confirm the exact version), while the called form
    # works everywhere.
    @torch.no_grad()
    def forward(self, prompts, with_mask: bool = True):
        """Tokenize ``prompts`` and run them through the text encoder.

        Gradients are disabled for the whole call.

        Args:
            prompts: Text handed straight to the tokenizer (presumably a
                string or a list of strings - verify against callers).
            with_mask: When true, the tokenizer's attention mask is passed to
                the encoder; when false, ``attention_mask=None`` is passed.

        Returns:
            A tuple ``(last_hidden_state, pooler_output)`` taken from the
            text encoder's output.
        """
        # Follow wherever the encoder weights currently live so the token
        # tensors end up on the same device.
        device = next(self.text_encoder.parameters()).device
        # Still exposed as an attribute for callers that read it after a
        # forward pass.
        self.device = device

        text_inputs = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        attention_mask = (
            text_inputs.attention_mask.to(device) if with_mask else None
        )
        prompt_embeds = self.text_encoder(
            text_inputs.input_ids.to(device),
            attention_mask=attention_mask,
        )
        return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output
        